Auto Encoder#

AE; Auto Encoder, 自己符号化器


オートエンコーダのイメージ
出典:オートエンコーダ– ニューラルネットワーク・DeepLearningなどの画像素材 プレゼン・ゼミなどに【WTFPL】

Hide code cell source
# packageのimport
import math 
from datetime import datetime
from typing import Any, Union, Callable, Type, TypeVar
from tqdm.auto import trange,tqdm
import numpy as np 
import numpy.typing as npt
import pandas as pd 
import matplotlib.pyplot as plt 
import plotly.express as px
import seaborn as sns
plt.style.use("bmh")
from sklearn.manifold import TSNE

# pytorch関連のimport
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

# animation
from matplotlib.animation import ArtistAnimation
from IPython.display import display, Markdown, HTML
Hide code cell output
/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
print("pytorch ver.:",torch.__version__)
print("numpy ver.:",np.__version__)
print("Apple Siliconが使える:", torch.backends.mps.is_available())
print("CUDAが使える:", torch.cuda.is_available())
!python -VV
pytorch ver.: 2.0.1
numpy ver.: 1.23.3
Apple Siliconが使える: True
CUDAが使える: False
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]

Auto Encoderとは#

このノートにおける AutoEncoder(以降 AE[Hinton and Salakhutdinov, 2006]とは,ニューラル自己符号化器と呼ばれるものを指します.私たちは単にこれをAEと呼び,入力されたデータベクトル(テンソル)を,より次元の小さいベクトルに圧縮した後に,そのベクトルを使って元のデータを再構築するようなニューラルネットワークを指します.

例えば28x28x1の画像データは,flattenすると784個の要素があることがわかりますね.これを50次元のベクトルに圧縮するニューラルネットワークを用意し,更に圧縮表現から元の784要素に戻すニューラルネットワークを作るのが今回の試みです.個別のネットワークとしてもいいですし,一気通貫に圧縮と展開を担うニューラルネットワークを作っても良いでしょう.このようにデータを圧縮するネットワーク(モジュール)を Encoder(エンコーダー) ,圧縮されたベクトルを元にデータを再構築するのを Decoder(デコーダー) と呼びます.

また,Encoderはデータ点を元あった空間から潜在的な変数空間へマップする作業をしていると考えられるので,圧縮されたベクトルを 潜在ベクトル などと呼ぶこともあります.

このようにあるデータを,元の次元数より小さい次元数のベクトルなどに圧縮する技術のことを,次元削減(Dimensionality reduction, dimension reduction) と呼びます.

AutoEncoderのアーキテクチャと実装#

最もシンプルなAE#

実装#

最も簡単なAEの構成は,全結合層と適切な活性化関数を用いて,特徴数→圧縮次元数→特徴数の順に変換する三層のニューラルネットワークです.
ここでは入力されるデータベクトルの各要素が0~1の範囲の値であるとして考えていきます.そのため,出力層の活性化関数はSigmoidにしてあります.また,pytorchにおいて画像データは一枚一枚が[チャネル, 縦, 横]の配列で表現されるので,バッチごとにデータを入力するならば[バッチサイズ, チャネル, 縦, 横]の配列になります.これを入力層からデータを入れるときには,これまでのMLPを思い出せば[バッチサイズ, チャネル\(\times\)\(\times\) 横]の配列にreshapeする必要がありそうです.

この前提でネットワークを組むと以下のようになります.

class SimpleAE(nn.Module):
    def __init__(self, in_features:int, n_components:int):
        super().__init__()
        self.in_features = in_features
        self.n_components = n_components
        # build layers
        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, self.n_components),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.n_components, self.in_features),
            nn.Sigmoid(),
        )

    def forward(self, x:torch.Tensor):
        h = self.encoder(x)
        return self.decoder(h)
    
ae = SimpleAE(10,2)
display(ae)
SimpleAE(
  (encoder): Sequential(
    (0): Linear(in_features=10, out_features=2, bias=True)
    (1): Sigmoid()
  )
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=10, bias=True)
    (1): Sigmoid()
  )
)

わかりやすいように,Encoder部分とDecoder部分をそれぞれnn.Sequentialインスタンスとしています.また,このモデルの損失関数には再構成誤差を用います.ここでは具体例として二乗和誤差を使うことにします.

チューニング#

この構成で良い結果が得られない場合は以下のことを検討すると良いでしょう.

  1. バッチノーマライゼーションの追加

  2. 活性化関数を変更

  3. より深層化

  4. 正則化項の追加

Weight Tying/ Weight Sharing#

Weight Tyingを使った線形変換レイヤの実装#

「最もシンプルなAE」から更に,パラメータ数を減らしてみましょう.ここで使うのが Weight TyingWeight Sharing と呼ばれるテクニックです.これはいくつかのレイヤーで結合重みを共有するというアプローチです.膨大な特徴量を受け取るようなnn.Linearクラスの重みを使い回すことでメモリを節約できます.ここではEncoderの全結合層とDecoderの全結合層の結合重みを共有します.

class WeightTyingLinear(nn.Module):
    """ほぼnn.Linearと同じで,結合重みだけ別のnn.Linearインスタンスを利用するクラス"""
    def __init__(self, shared_weights:torch.Tensor,bias:bool=True, 
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.weight = shared_weights.T
        self.out_features,self.in_features = self.weight.size()
        if bias:
            self.bias = nn.Parameter(torch.empty((self.out_features), **factory_kwargs)) 
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    
    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
    
    def extra_repr(self) -> str: 
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

重みの受け取りとbiasの初期化#

コンストラクタでshared_weightsとしてencoder側の重みを受け取り,これを使ってforwardメソッドで全結合層の処理を行います.biasは共有できない(encoderとdecoderではshapeが違います)ので,コンストラクタの中で初期化しています.

  • nn.Parameterクラス

    • nn.Moduleのコンストラクタで学習可能パラメータを用意するためのクラスです.

  • reset_parametersメソッド

    • nn.Linearから借用したメソッドです.self.biasはとりあえずtorch.emptyで初期化してありますが,これはモダンなbiasの初期化方法ではないのでreset_parametersメソッドで改めて初期化を行っています.

    • 具体的にはHeの初期化[He et al., 2015] を行います.

extra_reprメソッド#

extra_reprメソッドには「このクラスのインスタンスを直接printしたときに表示したい文字列」を返り値として設定してあげます.今回の実装だと例えば:

WeightTyingLinear(in_features=2, out_features=10, bias=True)

のような表示がされるはずです.今回はnn.Linearクラスの実装を参考にしているので,自作レイヤを作る際には公式実装を読んで書き方を掴んでください.

Auto Encoderクラスの実装#

AEクラスは前述のものとほぼ同じです.

class WeightTyingAE(nn.Module):
    def __init__(self, in_features:int, n_components:int):
        super().__init__()
        self.in_features = in_features
        self.n_components = n_components
        # build layers
        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, self.n_components),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            WeightTyingLinear(self.encoder[0].weight),
            nn.Sigmoid(),
        )

    def forward(self, x:torch.Tensor):
        h = self.encoder(x)
        return self.decoder(h)

本当にパラメータ共有ができているのかチェックしてみましょう.

ae2 = WeightTyingAE(10,2)
display(ae2)
print("パラメータが共有されているのかをチェック:")
for ix, params in enumerate(ae2.parameters()):
    print("--"*30)
    print(f">>> params_{ix}:")
    print(params)
    print(params.shape)
WeightTyingAE(
  (encoder): Sequential(
    (0): Linear(in_features=10, out_features=2, bias=True)
    (1): Sigmoid()
  )
  (decoder): Sequential(
    (0): WeightTyingLinear(in_features=2, out_features=10, bias=True)
    (1): Sigmoid()
  )
)
パラメータが共有されているのかをチェック:
------------------------------------------------------------
>>> params_0:
Parameter containing:
tensor([[-0.0668,  0.0970,  0.2833,  0.0253, -0.5156,  0.5497, -0.0550,  0.1866,
         -0.6601,  0.2233],
        [ 0.4639, -0.1422, -0.7014,  0.5481,  0.4235,  0.1275,  0.4139, -0.6170,
          0.0686,  0.0313]], requires_grad=True)
torch.Size([2, 10])
------------------------------------------------------------
>>> params_1:
Parameter containing:
tensor([-0.1974,  0.2932], requires_grad=True)
torch.Size([2])
------------------------------------------------------------
>>> params_2:
Parameter containing:
tensor([ 0.6214, -0.5454,  0.1737,  0.4014, -0.1944,  0.4143, -0.5110,  0.4357,
         0.0943, -0.0366], requires_grad=True)
torch.Size([10])

パラメータは4つの配列のはずが,この実装ではWeight Tyingによって一つ削減できていることがわかります.

画像の形状を変えずにネットワーククラスに入力する#

入力データの配列をできるだけそのまま(配列の形状を変更せずに[バッチサイズ, チャネル, 縦, 横]のまま)入力することを考えてみましょう.[バッチサイズ, チャネル, 縦, 横]を[バッチサイズ, チャネル\(\times\)\(\times\) 横]へ変形するための機能は,通常の行列計算ライブラリならflatten関数/メソッドかviewメソッド,またはreshapeメソッドで提供されています.

a = torch.zeros([3,1,2,2])
print("配列aの形状:", a.shape)
print("a=", a)

print("---"*20)
print("flattenメソッド:")
a1 = a.flatten(start_dim=1)
print("flattenを適用した配列aの形状:", a1.shape)
print("a=",a1)

print("---"*20)
print("viewメソッド:")
a2 = a.view((a.shape[0], -1))
print("viewを適用した配列aの形状:", a2.shape)
print("a=",a2)

print("---"*20)
print("reshapeメソッド:")
a3 = a.reshape((a.shape[0], -1))
print("reshapeを適用した配列aの形状:", a3.shape)
print("a=",a3)
配列aの形状: torch.Size([3, 1, 2, 2])
a= tensor([[[[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]]]])
------------------------------------------------------------
flattenメソッド:
flattenを適用した配列aの形状: torch.Size([3, 4])
a= tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
------------------------------------------------------------
viewメソッド:
viewを適用した配列aの形状: torch.Size([3, 4])
a= tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
------------------------------------------------------------
reshapeメソッド:
reshapeを適用した配列aの形状: torch.Size([3, 4])
a= tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

これらのメソッドのどれかをネットワーククラスのforwardクラスの最初に入力配列に対して適用し,出力層の出力をreshapeで入力配列と同じ形状に直すことで,ネットワークの外側で配列の形状を変える必要がなくなります.

ただし,nn.Sequentialの中に並べるには,nn.Moduleのサブクラスである必要があります.flattenはnn.Flattenというクラスがあるのでこれを使えばいいのですが,Reshapeは対応するクラスがありません.

Note

「nn.Sequentialの中に並べるには,nn.Moduleのサブクラスである必要があります.」
例えばnn.Sigmoidの代わりにF.sigmoidをnn.Sequentialに渡した場合:

nn.Sequential(
    nn.Linear(10,20), 
    F.sigmoid
    )

以下のエラーが発生:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[384], line 1
----> 1 nn.Sequential(
      2     nn.Linear(10,20), 
      3     F.sigmoid
      4     )

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/nn/modules/container.py:104, in Sequential.__init__(self, *args)
    102 else:
    103     for idx, module in enumerate(args):
--> 104         self.add_module(str(idx), module)

File ~/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/nn/modules/module.py:596, in Module.add_module(self, name, module)
    586 r"""Adds a child module to the current module.
    587 
    588 The module can be accessed as an attribute using the given name.
   (...)
    593     module (Module): child module to be added to the module.
    594 """
    595 if not isinstance(module, Module) and module is not None:
--> 596     raise TypeError("{} is not a Module subclass".format(
    597         torch.typename(module)))
    598 elif not isinstance(name, str):
    599     raise TypeError("module name should be a string. Got {}".format(
    600         torch.typename(name)))

TypeError: torch.nn.functional.sigmoid is not a Module subclass

こういうときには,nn.Moduleを継承したサブクラスを作成し,その中でreshapeに相当する関数を使うことになります.例えば以下のような形になります:

class Reshape(nn.Module):
    def __init__(self, img_shape:tuple[int]) -> None:
        super().__init__()
        self.img_shape = img_shape

    def forward(self,x:torch.Tensor)->torch.Tensor:
        return x.view((x.shape[0], *self.img_shape))
    
    def extra_repr(self) -> str:
        return 'in_features={}, out_img_shape={}'.format(
            np.prod(self.img_shape),self.img_shape,
        )
    
class SimpleAE(nn.Module):
    def __init__(self, img_shape:tuple[int], n_components:int):
        super().__init__()
        self.img_shape = img_shape
        self.in_features = np.prod(img_shape)
        self.n_components = n_components
        # build layers
        self.encoder = nn.Sequential(
            nn.Flatten(), # diff (batch_size, 1, 28, 28)->(batch_size, 1x28x28)
            nn.Linear(self.in_features, self.n_components),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.n_components, self.in_features),
            nn.Sigmoid(),
            Reshape(self.img_shape) # diff (batch_size, 1x28x28)->(batch_size, 1, 28, 28)
        )

    def forward(self, x:torch.Tensor):
        h = self.encoder(x)
        return self.decoder(h)

実験#

MNISTデータセット[Bottou et al., 1994]を圧縮する実験を行います.

データの準備#

MNISTのデータローダを用意しておきます.関数はCNNのところで使ったものと同じです.

Hide code cell source
def load_MNIST(batch=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize((0.1307,), (0.3081,)),
        #transforms.Lambda(lambda x: torch.flatten(x))
        ])

    train_set = torchvision.datasets.MNIST(root="./data",
                                           train=True,
                                           download=True,
                                           transform=transform)
    # train_validation_split.
    # see also https://qiita.com/takurooo/items/ba8c509eaab080e2752c#%E8%A7%A3%E6%B1%BA%E7%AD%962-torchutilsdatarandom_split.
    n_samples = len(train_set) # n_samples is 60000
    train_size = int(len(train_set) * 0.8) # train_size is 48000
    val_size = n_samples - train_size # val_size is 48000
    train_set, val_set = torch.utils.data.random_split(train_set, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch,
                                               shuffle=True,
                                               num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_set,
                                               batch_size=batch,
                                               shuffle=True,
                                               num_workers=2)
    

    test_set = torchvision.datasets.MNIST(root="./data",
                                         train=False,
                                         download=True,
                                         transform=transform)
    test_loader =torch.utils.data.DataLoader(test_set,
                                            batch_size=batch,
                                            shuffle=True,
                                            num_workers=2)

    return {"train":train_loader, "validation":val_loader, "test": test_loader}

skorchを使う場合はtraining/validation splitをskorchがやってくれるので,自分で行う必要がありません.

def load_MNIST_skorch(batch=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize((0.1307,), (0.3081,)),
        #transforms.Lambda(lambda x: torch.flatten(x))
        ])

    train_set = torchvision.datasets.MNIST(root="./data",
                                           train=True,
                                           download=True,
                                           transform=transform)

    test_set = torchvision.datasets.MNIST(root="./data",
                                         train=False,
                                         download=True,
                                         transform=transform)
    return {"train_dataset":train_set, "test_dataset": test_set}

一応中身を確認しておきましょう.

# 中身を確認しておく.今回は適当に40個取り出してみます.
_data_loader = load_MNIST(batch=1)
_train_loader= _data_loader["train"]

fig = plt.figure(figsize=[12,7])
for ix, batch in enumerate(_train_loader):
    img,y = batch
    ax = fig.add_subplot(4,10,ix+1)
    ax.imshow(img.view(-1,28), cmap='gray')
    if ix == 39:
        break
fig.show()
/var/folders/zn/y387lxqx5zx7j4t_3b0v0bg80000gn/T/ipykernel_92986/962318129.py:12: UserWarning: Matplotlib is currently using module://matplotlib_inline.backend_inline, which is a non-GUI backend, so cannot show the figure.
  fig.show()
_images/ae8b1d697f9afc473e74aee57c38eaca5dabeddf2e342ca5a49d79529891f1c2.png

訓練#

今回は訓練の進み具合を確認するために,1 epochごとに事前に選んでおいた画像を再構成させてみます.そのため,0~9の数字の手書き文字画像をランダムに一枚ずつ取り出す関数を用意します.

Hide code cell source
def get_sample():
    transform = transforms.Compose([
        transforms.ToTensor(),])
    test_set = torchvision.datasets.MNIST(root="./data",
                                         train=False,
                                         download=True,
                                         transform=transform)
    test_loader =torch.utils.data.DataLoader(test_set,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=2)
    sample = []
    i = 0
    for x,y in test_loader:
        if y == i:
            sample.append(x)
            i+=1
        if len(sample) == 10:
            break
    return sample

学習スクリプトはほぼいつも通りですが,学習過程の可視化をするために追加された部分があるので注意してください.

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
LATENT_DIM = 20
LEARNING_RATE = 0.01
MAX_EPOCHS = 50
BATCH_SIZE = 128
data_loader = load_MNIST(batch=BATCH_SIZE)
model = SimpleAE(np.prod(img.size()), LATENT_DIM).to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


#学習結果の保存
history = {
    "train_loss": [],
    "validation_loss": [],
    "validation_acc": [],
    "frames": [],
}

"""学習過程の可視化"""
# 訓練中に利用する画像を用意しておく
sample_imgs = get_sample()

# プロットの雛形を用意しておく
fig_anime = plt.figure(figsize=[15,5])
axs = []
for i in range(20):
    axs.append(fig_anime.add_subplot(2,10,i+1))
"""/学習過程の可視化"""

for epoch in trange(MAX_EPOCHS):
    # training step
    loss = None
    train_loss = 0.0
    model.train()
    for i,(x,y) in enumerate(data_loader["train"]):
        x,y = x.to(DEVICE),y.to(DEVICE)
        x = x.view(x.shape[0], -1)
        # 勾配の初期化
        optimizer.zero_grad()
        # 順伝搬 -> 逆伝搬 -> 最適化
        x_hat = model(x)
        loss = criterion(x_hat, x)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss /= len(data_loader["train"])
    history["train_loss"].append(float(train_loss))

    """validation step"""
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x,y in data_loader["validation"]:
            x,y = x.to(DEVICE),y.to(DEVICE)
            x = x.view(x.shape[0], -1)
            x_hat = model(x)
            val_loss = criterion(x_hat,x)

    val_loss /= len(data_loader["validation"])
    #print("Validation loss: {}\n".format(val_loss))
    history["validation_loss"].append(float(val_loss))
    
    """学習過程の可視化"""
    model.eval()
    with torch.no_grad():
        history["frames"].append([])
        for i in range(10):
            art1 = axs[i].imshow(sample_imgs[i].view(-1,28).detach().cpu().numpy(),
                          cmap='gray')
            sample_hat = model(sample_imgs[i].view([1,-1]).to(DEVICE))
            history["frames"][-1].append(art1)
            art2 = axs[i+10].imshow(sample_hat.view(-1,28).detach().cpu().numpy(),
                             cmap="gray")
            history["frames"][-1].append(art2)
    """/学習過程の可視化"""
    
plt.close(fig_anime) # 余計な図を表示させないようにする
Hide code cell output
  0%|                                                                                                                                                                                                                           | 0/50 [00:00<?, ?it/s]
  2%|████▏                                                                                                                                                                                                              | 1/50 [00:03<02:44,  3.36s/it]
  4%|████████▍                                                                                                                                                                                                          | 2/50 [00:06<02:45,  3.44s/it]
  6%|████████████▋                                                                                                                                                                                                      | 3/50 [00:10<02:52,  3.67s/it]
  8%|████████████████▉                                                                                                                                                                                                  | 4/50 [00:14<02:44,  3.57s/it]
 10%|█████████████████████                                                                                                                                                                                              | 5/50 [00:18<02:45,  3.67s/it]
 12%|█████████████████████████▎                                                                                                                                                                                         | 6/50 [00:21<02:36,  3.56s/it]
 14%|█████████████████████████████▌                                                                                                                                                                                     | 7/50 [00:24<02:29,  3.48s/it]
 16%|█████████████████████████████████▊                                                                                                                                                                                 | 8/50 [00:27<02:22,  3.40s/it]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/spawn.py", line 126, in _main

 16%|█████████████████████████████████▊                                                                                                                                                                                 | 8/50 [00:28<02:30,  3.58s/it]    exitcode = _main(fd, parent_sentinel)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/__init__.py", line 1247, in <module>
    self = reduction.pickle.load(from_parent)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/__init__.py", line 1247, in <module>
    import torch.backends.mps
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/backends/mps/__init__.py", line 30, in <module>
    import torch.backends.mps
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/backends/mps/__init__.py", line 30, in <module>
    from ..._refs import var_mean as _var_mean, native_group_norm as _native_group_norm
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/_refs/__init__.py", line 14, in <module>
    from ..._refs i
mport var_mean as _var_mean, native_group_norm as _native_group_norm
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/_refs/__init__.py", line 14, in <module>

Traceback (most recent call last):
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/zn/y387lxqx5zx7j4t_3b0v0bg80000gn/T/ipykernel_92986/3705282367.py", line 36, in <module>
    for i,(x,y) in enumerate(data_loader["train"]):
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
    success, data = self._try_get_data()
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/connection.py", line 262, in poll
    return self._poll(timeout)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/connection.py", line 429, in _poll
    r = wait([self], timeout)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2105, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1396, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1287, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1140, in structured_traceback
    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1030, in format_exception_as_a_whole
    self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1082, in get_records
    style = stack_data.style_with_executing_node(style, self._tb_highlight)
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/stack_data/core.py", line 455, in style_with_executing_node
    class NewStyle(style):
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/pygments/style.py", line 97, in __new__
    if styledef == 'noinherit':
  File "/Users/mriki/.pyenv/versions/miniforge3-4.10.3-10/envs/datasci/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 93040) is killed by signal: Terminated: 15. 
Unexpected exception formatting exception. Falling back to standard exception
_images/23abf61d1758025ac3dcc5e7a2ff265028100c139b1e5d3135770d588ee6633a.png

これでAEの訓練が完了したはずです.もちろんmax_epochsやlearning rateのようなハイパーパラメータの設定をより良いものに変えることで,もっと良いものにできる可能性はあります.興味があれば試してみてください.

訓練過程のモニタリング#

損失関数のモニタリング#


図:訓練中のMSEの変化

損失関数の値の変化を見ると,40エポックを過ぎたあたりで値が収束し始めていることがわかります.validation lossの方は値の増減が多少ありわかりづらいですが,100エポックを超えたあたりで上昇トレンドになりそうな兆しがあり,もう少しで過学習になるかもしれませんね.

[練習課題] 訓練が終わると,辞書historyに訓練中の損失関数の値が記録されているはずです.上に示したグラフ(図:訓練中のMSEの変化)と同様のグラフを作成する関数plot_lossを作成してください.この関数は引数として以下の二つを受け取ります.

  • history :dict

    • 上で作成した辞書.

  • save_path :str | None

    • 作ったグラフの保存先を示すpath.Noneの場合は保存しない.

Hide code cell source
def plot_loss(history, save_path=None):
    fig = plt.figure(figsize=[15,4])
    fig.suptitle("Reconstruction Error")
    fig.supxlabel("Epoch")
    fig.supylabel("MSE")

    ax = fig.add_subplot(1,3,1)
    ax.set_title("training loss and validation loss")
    ax.plot(history["validation_loss"], label="validation_loss", color="red")
    ax.plot(history["train_loss"], label="train_loss", color="blue")
    ax.legend()


    ax2 = fig.add_subplot(1,3,2)
    ax2.set_title("training loss")
    ax2.plot(history["train_loss"], label="train_loss", color="blue")

    ax3 = fig.add_subplot(1,3,3)
    ax3.set_title("validation loss")
    ax3.plot(history["validation_loss"], label="validation_loss", color="red")

    if save_path is not None:
        fig.savefig(save_path)
    plt.close()
    return fig

fig = plot_loss(history)
#display(fig)

再構築画像のモニタリング#

訓練を重ねるごと(epochごと)に,AEがどれだけ上手く画像を再現できるようになったのかを確認してみましょう.上の学習がうまくいった方は,既に表示されているアニメーションを確認してください.また,以下に事前に訓練を行った際に作成したアニメーションを示しました.これらから,AEが徐々に上手く画像の再構成ができるようになっていることがわかると思います.

with open("./figs/ae/autoencoder_mnist.html", "r") as f:
    anim_html = f.read()
    
display(HTML(anim_html))
display(Markdown("アニメーション①:「AEの学習の様子」.(上段:オリジナルの画像,下段:再構築した画像.シークバーを進めるとepochが進みます."))

アニメーション①:「AEの学習の様子」.(上段:オリジナルの画像,下段:再構築した画像.シークバーを進めるとepochが進みます.

上記のアニメーションは以下のコードで作成できます.

Hide code cell source
anim = ArtistAnimation(fig_anime, history["frames"], interval=150)
dt_now = datetime.datetime.now().strftime('%Y年%m月%d日%H時%M分%S秒')
anim.save(f"autoencoder_mnist-{dt_now}.gif") # 今回は実験終了時間をファイル名に含めることにします.
with open(f"autoencoder_mnist-{dt_now}.html", "w") as f:
    f.write(anim.to_jshtml())

display(HTML(anim.to_jshtml()))

潜在空間の可視化#

今回作ったAEでは,Encoder側出力は入力された画像を20次元の潜在的空間の点として表現しています.潜在空間上でどのように画像が点として描かれるのかを見るには,20次元だと可視化することが難しいので,どうにかして2次元の潜在空間上の点になるようにする必要があります.しかしながら,LATENT_DIM=2に設定して再訓練してもうまく行きません.

例として,LATENT_DIM=2として訓練したAEを使い,テストデータを散布図として表示してみることにします.

テストデータの画像はクラスごとに何かしらの特徴を持っているはずなので,潜在空間の中でもある程度はクラスごとにグルーピングされていると嬉しいはずです.しかしながら上の散布図は全くの不規則に散らばっているように見えます.これは潜在空間の次元数が2だけだと,AEの学習がうまくいっていないことを示しています.

学習がうまくいっていないことを確認するために,画像の再構成結果を示します.

20次元の時とは異なり,まったく再現できていないことが分かりました.AEで直接2次元に圧縮するのは(このままだと)難しいようです.

そこでここでは,可視化のための時限削減アルゴリズムであるt-SNE[van der Maaten and Hinton, 2008]を利用して,以下の条件で作成した画像を見比べることで「潜在空間がどれだけ重要な情報を残せているのか」をチェックしてみます.

  1. 画像をそのままt-SNEで2次元に圧縮して作成した散布図

  2. 画像をAEで2次元に圧縮した後に,t-SNEで2次元に圧縮して作成した散布図

  3. 画像をAEで20次元に圧縮した後に,t-SNEで2次元に圧縮して作成した散布図

まずは1の画像を見てみましょう.

次に2の画像を見てみましょう.

最後に3の画像を見てみましょう.

上の図を見る限り:

  • (1)TSNEで元のデータセットを2次元散布図にした所,クラスごとにグループになって表示されていることがわかります.

  • これに対して(2)AEで2次元にした後にTSNEにかけたものは,クラスごとにまとまった様子はありません.

  • 最後に(3)AEで20次元にした後にTSNEにかけたものは,(1)と同じようにクラスごとにグループになって表示されているようです.

2次元空間はあくまでも相対的なものなので,クラスごとの位置関係が(1)と(3)で変わっていることには目を瞑ってください.ここから分かることは,クラスに関する何かしらの情報を(1)と(3)は保持しているのに対して,(2)はそれが欠落しているということです.AEでは際限なく小さな次元にデータを圧縮できるわけではなく,圧縮率を上げると情報が落ちすぎて無意味になる場合があることを覚えておきましょう.

参考文献#

BCD+94

L. Bottou, C. Cortes, J.S. Denker, H. Drucker, I. Guyon, L.D. Jackel, Y. LeCun, U.A. Muller, E. Sackinger, P. Simard, and V. Vapnik. Comparison of classifier methods: a case study in handwritten digit recognition. In Proceedings of the 12th IAPR International Conference on Pattern Recognition, Vol. 3 - Conference C: Signal Processing (Cat. No.94CH3440-5), volume 2, 77–82 vol.2. 1994. doi:10.1109/ICPR.1994.576879.

HZRS15

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: surpassing human-level performance on ImageNet classification. In 2015 IEEE International Conference on Computer Vision (ICCV). IEEE, December 2015. doi:10.1109/iccv.2015.123.

HS06

G E Hinton and R R Salakhutdinov. Reducing the dimensionality of data with neural networks. Science, 313(5786):504–507, July 2006. doi:10.1126/science.1127647.

vdMH08

Laurens van der Maaten and Geoffrey Hinton. Visualizing data using t-SNE. J. Mach. Learn. Res., 9(86):2579–2605, 2008.